提出十字交叉注意力模块,使用循环稀疏连接代替密集连接,实现性能SOTA
论文名称:CCNet: Criss-Cross Attention for Semantic Segmentation
作者:Zilong Huang,Xinggang Wang Yun,chao Wei,Lichao Huang,Wenyu Liu,Thomas S. Huang
摘要
上下文信息在视觉理解问题中至关重要,譬如语义分割和目标检测;
本文提出了一种十字交叉的网络(Criss-Cross Net)以非常高效的方式获取完整的图像上下文信息:
- 对每个像素使用一个十字注意力模块聚集其路径上所有像素的上下文信息;
- 通过循环操作,每个像素最终都可以捕获完整的图像相关性;
- 提出了一种类别一致性损失来增强模块的表现。
CCNet具有一下优势:
- 显存友好:相较于Non-Local减少显存占用11倍
- 计算高效:循环十字注意力减少Non-Local约85%的计算量
- SOTA
- Achieve the mIoU scores of 81.9%, 45.76% and 55.47% on the Cityscapes test set, the ADE20K validation set and the LIP validation set respectively
介绍
- 当前FCN在语义分割任务取得了显著进展,但是由于固定的几何结构,分割精度局限于FCN局部感受野所能提供的短程感受野,目前已有相当多的工作致力于弥补FCN的不足,相关工作看论文。
- 密集预测任务实际上需要高分辨率的特征映射,因此Non-Local的方法往往计算复杂度高,并且占用大量显存,因此设想使用几个连续的稀疏连通图(sparsely-connected graphs)来替换常见的单个密集连通图( densely-connected graph),提出CCNet使用稀疏连接来代替Non-Local的密集连接。
- 为了推动循环十字注意力学习更多的特征,引入了类别一致损失(category consistent loss)来增强CCNet,其强制网络将每个像素映射到特征空间的n维向量,使属于同一类别的像素的特征向量靠得更近。
方法
CCNet可能是受到之前将卷积运算分解为水平和垂直的GCN以及建模全局依赖性的Non-Local,CCNet使用的十字注意力相较于分解更具优势,拥有比Non-Local小的多得计算量。
文中认为CCNet是一种图神经网络,特征图中的每个像素都可以被视作一个节点,利用节点间的关系(上下文信息)来生成更好的节点特征。
最后,提出了同时利用时间和空间上下文信息的3D十字注意模块。
网络结构
整体流程如下:
- 对于给定的,使用卷积层获得降维的特征映射;
- 会输入十字注意力模块以生成新的特征映射,其中每个像素都聚集了垂直和水平方向的信息;
- 进行一次循环,将输入十字注意力,得到,其中每个像素实际上都聚集了所有像素的信息;
- 将与局部特征表示进行;
- 由后续的网络进行分割。
Criss-Cross Attention
主要流程如下:
使用卷积进行降维得到;
通过Affinity操作生成注意力图,其中:
对于空间维度上的的每一个位置,我们可以得到一个向量;
同时,我们在上得到一个集合,其代表着位置的同一行或同一列;
令表示的第个元素,Affinity操作可以表示为:
其用来表示两者之间的相关性,最终我们可以得到
最终在通道维度上对使用,即可得到注意力图,需要注意的是,这里的通道维度代表的是这个维度,其表示某个位置像素与其垂直水平方向上像素的相关性。
另一方面,依旧使用卷积生成,我们可以获得一个向量和一个集合
最后使用Aggregation操作得到最终的特征图,其定义为:
其中是某个位置的特征向量。
至此,我们已经能够捕获某个位置像素水平和垂直方向上的信息,然而,该像素与周围的其他像素仍然不存在关系,为了解决这个问题,提出了循环机制。
Recurrent Criss-Cross Attention (RCCA)
通过多次使用CCA来达到对上下文进行建模,当循环次数R=2时,特征图中任意两个空间位置的关系可以定义为:
方便起见,对于特征图上的两个位置和,其信息传递示意图如下:
可以看到,经过两次循环,原本不相关的位置也能够建立联系了。
Learning Category Consistent Features
对于语义分割任务,属于同一类别的像素应该具有相似的特征,而来自不同类别的像素应该具有相距很远的特征。
然而,聚集的特征可能存在过度平滑的问题,这是图神经网络中的一个常见问题,为此,提出了类别一致损失。
其中的距离函数设计为分段形式,公式如下:
本文中,距离阈值的设置为
为了加速计算,对RCCA的输入进行降维,其比率设置为16
总的损失函数定义如下:
本文中,的值分别为1,1,0.001,
3D Criss-Cross Attention
在2D注意力的基础上进行推广,提出3DCCA,其可以在时间维度上收集额外的上下文信息
其流程与2DCCA大致相同,具体细节差异看论文。
代码复现
Criss-Cross Attention
def INF(B,H,W):
# tensor -> torch.size([H]) -> 对角矩阵[H,H] -> [B*W,H,H]
# 消除重复计算自身的影响
return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(nn.Module):
""" Criss-Cross Attention Module"""
def __init__(self, in_ch,ratio=8):
super(CrissCrossAttention,self).__init__()
self.q = nn.Conv2d(in_ch, in_ch//ratio, 1)
self.k = nn.Conv2d(in_ch, in_ch//ratio, 1)
self.v = nn.Conv2d(in_ch, in_ch, 1)
self.softmax = nn.Softmax(3)
self.INF = INF
self.gamma = nn.Parameter(torch.zeros(1)) # 初始化为0
def forward(self, x):
bs, _, h, w = x.size()
# Q
x_q = self.q(x)
# b,c',h,w -> b,w,c',h -> b*w,c',h -> b*w,h,c'
# 后两维相当于论文中的Q_u
x_q_H = x_q.permute(0,3,1,2).contiguous().view(bs*w,-1,h).permute(0, 2, 1)
# b,c',h,w -> b,h,c',w -> b*h,c',w -> b*h,w,c'
x_q_W = x_q.permute(0,2,1,3).contiguous().view(bs*h,-1,w).permute(0, 2, 1)
# K
x_k = self.k(x) # b,c',h,w
# b,c',h,w -> b,w,c',h -> b*w,c',h
x_k_H = x_k.permute(0,3,1,2).contiguous().view(bs*w,-1,h)
# b,c',h,w -> b,h,c',w -> b*h,c',w
x_k_W = x_k.permute(0,2,1,3).contiguous().view(bs*h,-1,w)
# V
x_v = self.v(x)
# b,c,h,w -> b,w,c,h -> b*w,c,h
x_v_H = x_v.permute(0,3,1,2).contiguous().view(bs*w,-1,h)
# b,c,h,w -> b,h,c,w -> b*h,c,w
x_v_W = x_v.permute(0,2,1,3).contiguous().view(bs*h,-1,w)
# torch.bmm计算三维的矩阵乘法,如[bs,a,b][bs,b,c]
# 先计算所有Q_u和K上与位置u同一列的
energy_H = (torch.bmm(x_q_H, x_k_H)+self.INF(bs, h, w)).view(bs,w,h,h).permute(0,2,1,3) # b,h,w,h
# 再计算行
energy_W = torch.bmm(x_q_W, x_k_W).view(bs,h,w,w)
# 得到注意力图
concate = self.softmax(torch.cat([energy_H, energy_W], 3)) # b,h,w,h+w
# 后面开始合成一张图
att_H = concate[:,:,:,0:h].permute(0,2,1,3).contiguous().view(bs*w,h,h)
#print(concate)
#print(att_H)
att_W = concate[:,:,:,h:h+w].contiguous().view(bs*h,w,w)
# 同样的计算方法
out_H = torch.bmm(x_v_H, att_H.permute(0, 2, 1)).view(bs,w,-1,h).permute(0,2,3,1) # b,c,h,w
out_W = torch.bmm(x_v_W, att_W.permute(0, 2, 1)).view(bs,h,-1,w).permute(0,2,1,3) # b,c,h,w
#print(out_H.size(),out_W.size())
return self.gamma*(out_H + out_W) + x # 乘积使得整体可训练
Category Consistent Loss
实验
在Cityscapes、ADE20K、COCO、LIP和CamVid数据集上进行了实验,在一些数据集上实现了SOTA,并且在Cityscapes数据集上进行了消融实验。
实验结果
在Cityscapes上的结果:
消融实验
RCCA模块
通过改变循环次数进行了如下实验:
可以看到,RCCA模块可以有效的聚集全局上下文信息,同时保持较低的计算量。
为了进一步验证CCA的有效性,进行了定性比较:
随着循环次数的增加,这些白色圈圈区域的预测逐渐得到纠正,这证明了密集上下文信息在语义分割中的有效性。
类别一致损失
上图中的CCL即表示使用了类别一致损失
上述结果表明了分段距离和类别一致损失的有效性。
对比其他聚集上下文信息方法
同时,对Non Local使用了循环操作,可以看到,循环操作带来了超过一点的增益,然而其巨量的计算量和显存需求限制性能
可视化注意力图
上图中可以看到循环操作的有效性。
更多实验
在ADE20K上的实验验证了类别一致损失(CCL)的有效性:
在LIP数据集的实验结果:
在COCO数据集的实验结果:
在CamVid数据上的实验结果: